# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This module implements a PyNccl pipe for sending and receiving
Optional[torch.Tensor] between distributed ranks with advanced
communication features.

Key Features:
- Supports sending and receiving tensors with metadata
- Handles both CUDA and CPU device communications
- Implements a non-blocking tensor transfer mechanism
- Manages buffer size and provides backpressure control
- Supports distributed process groups with configurable parameters
"""

import threading
import time
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor

import torch

from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger

logger = init_logger(__name__)


class BrokenPipeException(Exception):
    """Raised when a KV transfer pipe can no longer deliver data."""

    def __init__(self, message):
        super().__init__(message)
        self.message = message


Metadata = dict[str, torch.Tensor | None]


class PyNcclPipe(KVPipeBase):
    """
    Point-to-point pipe for sending/receiving ``Optional[torch.Tensor]``
    between two KV-transfer ranks.

    Metadata (dtype / shape) is exchanged via the stateless process group's
    object send/recv; the tensor payload travels over PyNCCL on CUDA devices,
    or over object send/recv on CPU. Sends are asynchronous (a single-worker
    thread pool) with backpressure bounded by ``kv_buffer_size``; receives
    are blocking.
    """

    METADATA_LENGTH = 16
    MAX_TENSOR_DIMENSIONS = 14
    METADATA_DTYPE = torch.int64

    def __init__(
        self,
        local_rank: int,
        config: KVTransferConfig,
        device: str | None = None,
        port_offset: int = 0,
    ):
        """
        Args:
            local_rank: Local device index; selects the CUDA device when the
                buffer device is "cuda".
            config: KV transfer configuration (ranks, ip/port, buffer size).
            device: Optional device override; defaults to
                ``config.kv_buffer_device``.
            port_offset: Offset added to ``config.kv_port`` so multiple pipes
                can coexist on one host.
        """
        self.config = config
        self.local_rank = local_rank
        self.kv_rank = self.config.kv_rank
        assert self.kv_rank is not None
        self.kv_parallel_size = self.config.kv_parallel_size
        if device is None:
            self.device = self._select_device(self.config.kv_buffer_device)
        else:
            self.device = self._select_device(device)

        # build distributed connection and send/recv implementation
        store_timeout = self.config.get_from_extra_config("store_timeout", 300)
        self.group = StatelessProcessGroup.create(
            host=self.config.kv_ip,
            port=self.config.kv_port + port_offset,
            rank=self.kv_rank,
            world_size=self.kv_parallel_size,
            store_timeout=store_timeout,
        )
        # add a barrier to make sure the connection is initiated properly
        self.group.barrier()
        impl = self._get_device_send_recv_impl(self.group)
        self.device_send_func, self.device_recv_func = impl
        # ring topology: send to the next rank, receive from the previous one
        self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size
        self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size

        # transportation-related variables
        self.transport_thread: ThreadPoolExecutor | None = None
        self.buffer_size = 0
        self.buffer_size_lock = threading.Lock()
        self.buffer_size_thresh = self.config.kv_buffer_size

    def _get_device_send_recv_impl(
        self, group: StatelessProcessGroup
    ) -> tuple[
        Callable[[torch.Tensor, int], None], Callable[[torch.Tensor, int], None]
    ]:
        """
        Select the payload send/recv callables for the configured device.

        Args:
            group: The stateless process group used for communication.

        Returns:
            A ``(send, recv)`` pair; each takes ``(tensor, peer_rank)``.
        """
        send: Callable[[torch.Tensor, int], None]
        recv: Callable[[torch.Tensor, int], None]
        if self.device.type == "cuda":
            # use PyNCCL for send / recv
            comm = PyNcclCommunicator(group, device=self.local_rank)
            comm.disabled = False
            send, recv = comm.send, comm.recv  # type: ignore
        else:
            # This send / recv implementation here is NOT intended to transfer
            # KV caches (and should NOT be repurposed to transfer KV caches).
            # Currently it is only used to transmit control-plane messages
            # for PyNcclBuffer.
            send = group.send_obj

            def my_recv(x, src):
                # Copy the received object into the caller's pre-allocated
                # buffer, mirroring the in-place semantics of comm.recv.
                x[...] = group.recv_obj(src)

            recv = my_recv

        return send, recv

    def _select_device(self, device: str):
        """
        Resolve a device string to a ``torch.device``.

        Anything other than "cuda" falls back to CPU; "cuda" is pinned to
        ``self.local_rank``.
        """
        logger.info("Selecting device: %s", device)
        if device == "cuda":
            return torch.device(f"cuda:{self.local_rank}")
        else:
            return torch.device("cpu")

    def _make_metadata(self, tensor: torch.Tensor | None) -> Metadata:
        """
        Create the metadata as a dictionary based on the input tensor.

        Args:
            tensor: The input tensor or None if no tensor is provided.

        Returns:
            metadata: A dictionary with the following keys:
                - "dtype": The data type of the tensor or None.
                - "shape": The shape of the tensor or None.
        """
        if tensor is None:
            return {"dtype": None, "shape": None}
        else:
            return {"dtype": tensor.dtype, "shape": tensor.shape}

    def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor:
        """
        Create a buffer to receive the tensor based on the provided metadata.

        Args:
            metadata: A dictionary with keys "dtype" and "shape",
                describing the tensor's data type and shape.

        Returns:
            buffer: A tensor of the specified type and shape,
                allocated on `self.device`.
        """
        return torch.empty(
            metadata["shape"], dtype=metadata["dtype"], device=self.device
        )

    def _send_metadata(self, metadata: Metadata):
        """
        Send the metadata dictionary to the target rank.

        Args:
            metadata: A dictionary with keys "dtype" and "shape".
        """
        self.group.send_obj(metadata, self.target_rank_for_send)

    def _recv_metadata(self) -> Metadata:
        """
        Receive the metadata dictionary from the target rank.

        Returns:
            metadata: A dictionary with keys "dtype" and "shape"
                describing the tensor.
        """
        return self.group.recv_obj(self.target_rank_for_recv)

    def _send_impl(self, tensor: torch.Tensor | None) -> None:
        """
        The actual implementation of sending the tensor and its metadata to the
        target rank.

        Args:
            tensor: The input tensor to be sent, or `None` if no tensor is
                being sent.
        """
        metadata = self._make_metadata(tensor)
        self._send_metadata(metadata)
        if tensor is not None:
            self.device_send_func(tensor.to(self.device), self.target_rank_for_send)

    def _recv_impl(self) -> torch.Tensor | None:
        """
        The actual implementation of receiving a tensor and its metadata from
        the target rank.

        Returns:
            buffer: The received tensor, or `None` if no tensor is received.
        """
        metadata = self._recv_metadata()
        # A None dtype is the sentinel for "no tensor was sent".
        if metadata["dtype"] is None:
            return None
        buffer = self._prepare_recv_buffer(metadata)
        self.device_recv_func(buffer, self.target_rank_for_recv)

        return buffer

    def send_tensor_wrapper(
        self, tensor: torch.Tensor | None, tensor_size: int
    ) -> None:
        """
        Wrapper for _send_impl: logs exceptions and releases the buffer
        space reserved by send_tensor().

        Runs on the transport thread. Exceptions are logged rather than
        propagated because nothing awaits the submitted future.
        """
        try:
            self._send_impl(tensor)
        except Exception:
            # NOTE(review): torch.distributed.get_rank() assumes the default
            # process group is initialized — confirm; only used for logging.
            logger.exception(
                "[rank%d]: Exception when trying to send %s",
                torch.distributed.get_rank(),
                str(tensor),
            )
        finally:
            # BUGFIX: always release the space reserved in send_tensor().
            # Previously this only happened on success, so a failed send
            # permanently inflated buffer_size and block_if_full() could
            # then spin forever.
            with self.buffer_size_lock:
                self.buffer_size -= tensor_size

    def block_if_full(self):
        """
        Block the current thread if the buffer size is larger than the
        threshold.
        """
        # The unlocked read is intentional: a stale value only delays the
        # wake-up by one 50 ms poll interval.
        while self.buffer_size > self.buffer_size_thresh:
            logger.debug("KV cache transfer pipe is full. Waiting...")
            time.sleep(0.05)

    def send_tensor(self, tensor: torch.Tensor | None) -> None:
        """
        Sends a tensor and its metadata to the destination rank in a
        non-blocking way.

        Args:
            tensor: The tensor to send, or `None` if no tensor is being sent.
        """
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)

        if tensor is not None:
            tensor_size = tensor.element_size() * tensor.numel()
        else:
            tensor_size = 0

        # Backpressure: wait until in-flight bytes drop below the threshold,
        # then reserve space before handing off to the transport thread.
        self.block_if_full()

        with self.buffer_size_lock:
            self.buffer_size += tensor_size

        self.transport_thread.submit(self.send_tensor_wrapper, tensor, tensor_size)

    def recv_tensor(self) -> torch.Tensor | None:
        """
        Receives a tensor and its metadata from the source rank. Blocking call.

        Returns:
            The received tensor, or `None` if no tensor is received.

        Raises:
            Exception: re-raises whatever the receive thread raised.
        """
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)

        future = self.transport_thread.submit(self._recv_impl)

        try:
            tensor = future.result()
        except Exception:
            # logger.exception records the full traceback, replacing the
            # manual traceback.print_exc(); bare `raise` preserves the
            # original exception and traceback for the caller.
            logger.exception(
                "Encountering exception in KV receiving thread. My device: %s",
                self.device,
            )
            raise

        return tensor

    def close(self):
        """
        Close the pipe and release associated resources.
        """
        # hasattr guard: __init__ may have failed before the attribute was set.
        if hasattr(self, "transport_thread") and self.transport_thread is not None:
            self.transport_thread.shutdown()
